[WIP][BACKEND] Generalize the MemBar to consider cross-CTA ops #8834
base: main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
lib/Analysis/Utility.cpp
Outdated
if (auto blockArg = dyn_cast<BlockArgument>(cur)) {
  auto yield = cast<scf::YieldOp>(blockArg.getOwner()->getTerminator());
  cur = yield.getOperand(blockArg.getArgNumber() - 1);
} else {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Loop-carried memdesc lookup underflows on non-for regions
The new findShmemAlloc assumes every memdesc BlockArgument is preceded by an induction variable and pulls the defining value with yield.getOperand(blockArg.getArgNumber() - 1). In scf.while/other regions whose block arguments are one-to-one with the yielded values, getArgNumber() can be 0, so subtracting 1 dereferences a negative index and trips an assertion/UB when membar analysis touches a memdesc carried through such a loop. This is a regression from the previous implementation (in TritonGPU/Transforms/Utility.cpp) which only subtracted one for scf.for blocks and otherwise handled block args directly.
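A minimal sketch of how the old guard could be restored (the variable names and exact structure here are illustrative, not the PR's actual code):

```cpp
// Only skip an operand for scf.for, whose region prepends the induction
// variable; in scf.while and other regions the block arguments map
// one-to-one onto the yielded values, so argNumber can be 0.
if (auto blockArg = dyn_cast<BlockArgument>(cur)) {
  Block *block = blockArg.getOwner();
  auto yield = cast<scf::YieldOp>(block->getTerminator());
  unsigned argNo = blockArg.getArgNumber();
  if (isa<scf::ForOp>(block->getParentOp()))
    argNo -= 1; // account for the induction variable
  cur = yield.getOperand(argNo);
}
```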
ThomasRaoux left a comment
I wonder if we could separate it from membar in order to decouple the logic
    auto writesUFDS = BlockInfo::UFDS(numCTAs);
    return {readsUFDS, writesUFDS};
  }
} else if (auto tma =
might be worth having an interface so that we don't need to add every op explicitly in membar
I can do it for tma and ops that take a memdesc explicitly (so they have effects) but for the other ops (just convert_layout and reduce) I have to do it manually, I think. I can do it manually for both of them in one go tho.
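As a rough illustration of the interface idea (the interface name and method below are hypothetical, not an existing Triton API), membar could query ops instead of enumerating them:

```cpp
// Hypothetical sketch: each op reports its own cross-CTA read/write
// classes instead of being special-cased in the membar analysis.
if (auto crossCTA = dyn_cast<CrossCTAMemOpInterface>(op)) {
  auto [readsUFDS, writesUFDS] = crossCTA.getCrossCTAAccesses(numCTAs);
  // ... fold the classes into the current BlockInfo as before.
}
```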
include/triton/Analysis/Membar.h
Outdated
// Skip if filtered or both ops touch the same explicit shared
// allocation (same local_alloc).
return !((filter && filter(lhsOp, rhsOp)) ||
         (joined.isDistributed() && haveSameAlloc(lhsOp, rhsOp)));
why do we have to check for haveSameAlloc here? We should know that those intersect already.
I don't think it is safe to assume that we can always track back the alloc
Changed the alloc tracking to a backwardSlice. I'd hope this then gives us a pessimistic but correct analysis.
The context as to why we need to do this is in the OP.
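A sketch of what the backwardSlice-based tracking could look like (the helper name is an assumption, and the exact getBackwardSlice signature varies by MLIR version; newer versions return a LogicalResult):

```cpp
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"

// Collect every local_alloc reachable backwards from the memdesc. Two ops
// are only known to touch the same buffer when both slices resolve to the
// same single alloc; any ambiguity falls back to keeping the barrier.
static llvm::SetVector<mlir::Operation *> findShmemAllocs(mlir::Value memDesc) {
  llvm::SetVector<mlir::Operation *> slice;
  (void)mlir::getBackwardSlice(memDesc, &slice, mlir::BackwardSliceOptions());
  llvm::SetVector<mlir::Operation *> allocs;
  for (mlir::Operation *op : slice)
    if (llvm::isa<mlir::triton::gpu::LocalAllocOp>(op))
      allocs.insert(op);
  return allocs;
}
```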
Jokeren left a comment
OK I roughly understand the ideas here. Will wait for @lezcano to ping me once he considers it ready for review.
lib/Analysis/Membar.cpp
Outdated
OpBuilder::InsertionGuard g(*builder);
auto barrierOp = triton::gpu::LocalBarrierOp::create(*builder, op->getLoc());
if (ctaClasses.isDistributed()) {
  // TODO Insert a finer barrier when there is more than one CTA class
What does "finer" barrier mean here? Can you clarify?
The idea is that using mbarriers we can often do better than synchronising all the CTAs.
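For context, a sketch of the direction this hints at (the cluster ops are Triton's NVIDIA-dialect ops, but their use here and the per-class mbarrier variant are illustrative only):

```cpp
// Today: a CTA-local barrier, or a full cluster sync when accesses are
// distributed across CTAs. A "finer" scheme would arrive/wait only on an
// mbarrier shared by the CTAs inside each UFDS class.
if (ctaClasses.isDistributed()) {
  triton::nvidia_gpu::ClusterArriveOp::create(*builder, op->getLoc(),
                                              /*relaxed=*/false);
  triton::nvidia_gpu::ClusterWaitOp::create(*builder, op->getLoc());
} else {
  triton::gpu::LocalBarrierOp::create(*builder, op->getLoc());
}
```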
include/triton/Analysis/Membar.h
Outdated
struct BlockInfo {
  using IntervalMapT = std::map<Interval<size_t>, std::set<Operation *>>;
  // UFDS to represent cross-CTA reads/writes
  struct UFDS {
Might be better to make the name more explicit so that the struct name self-explains. e.g., CrossCTAUnionFindSet.
Going with CTA_UFDS to keep things tight. I will expand the name in the comment above the class tho
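For readers unfamiliar with the acronym (union-find disjoint set), a minimal standalone sketch of the idea behind CTA_UFDS; the isDistributed semantics below are my reading of the thread, not the PR's exact code:

```cpp
#include <numeric>
#include <vector>

// Union-find over CTA ids: join(a, b) records that an op couples CTA a and
// CTA b through shared memory, so they belong to the same sync class.
struct CTA_UFDS {
  std::vector<unsigned> parent;
  explicit CTA_UFDS(unsigned numCTAs) : parent(numCTAs) {
    std::iota(parent.begin(), parent.end(), 0u);
  }
  unsigned find(unsigned x) {
    while (parent[x] != x)
      x = parent[x] = parent[parent[x]]; // path halving
    return x;
  }
  void join(unsigned a, unsigned b) { parent[find(a)] = find(b); }
  // Distributed once any join merged two distinct CTAs, i.e. some access
  // crosses a CTA boundary and needs more than a CTA-local barrier.
  bool isDistributed() {
    for (unsigned i = 0; i < parent.size(); ++i)
      if (find(i) != i)
        return true;
    return false;
  }
};
```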
On decoupling it from membar: if we decouple it, we are going to generate strictly worse code. The thing here is that we'd first run the membar pass and put a …
The semantics here are that it's the user's/compiler's responsibility to add the relevant synchronisation if they reuse the same shmem buffer, but otherwise the compiler will do so.